/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.test.operators;

import org.apache.flink.api.common.InvalidProgramException;
import org.apache.flink.api.common.functions.FilterFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapFunction;
import org.apache.flink.api.common.operators.Order;
import org.apache.flink.api.java.DataSet;
import org.apache.flink.api.java.ExecutionEnvironment;
import org.apache.flink.api.java.functions.KeySelector;
import org.apache.flink.api.java.operators.AggregateOperator;
import org.apache.flink.api.java.operators.DataSource;
import org.apache.flink.api.java.operators.DeltaIteration;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.tuple.Tuple3;
import org.apache.flink.test.operators.util.CollectionDataSets;
import org.apache.flink.test.operators.util.CollectionDataSets.POJO;
import org.apache.flink.test.util.MultipleProgramsTestBase;
import org.apache.flink.util.Collector;

import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.io.Serializable;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;

/**
 * Integration tests for the partitioning operators of the {@link DataSet} API: hash partitioning
 * ({@code partitionByHash}), range partitioning ({@code partitionByRange}, optionally with custom
 * orders) and forced rebalancing ({@code rebalance}).
 */
@RunWith(Parameterized.class)
@SuppressWarnings("serial")
public class PartitionITCase extends MultipleProgramsTestBase {

    public PartitionITCase(TestExecutionMode mode) {
        super(mode);
    }

    @Test
    public void testHashPartitionByKeyField() throws Exception {
        /*
         * Test hash partition by key field
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs = ds.partitionByHash(1).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testRangePartitionByKeyField() throws Exception {
        /*
         * Test range partition by key field
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs = ds.partitionByRange(1).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testHashPartitionByKeyField2() throws Exception {
        /*
         * Test hash partition by key field
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        AggregateOperator<Tuple3<Integer, Long, String>> sum =
                ds.map(new PrefixMapper()).partitionByHash(1, 2).groupBy(1, 2).sum(0);

        List<Tuple3<Integer, Long, String>> result = sum.collect();

        String expected =
                "(1,1,Hi)\n"
                        + "(5,2,Hello)\n"
                        + "(4,3,Hello)\n"
                        + "(5,3,I am )\n"
                        + "(6,3,Luke )\n"
                        + "(34,4,Comme)\n"
                        + "(65,5,Comme)\n"
                        + "(111,6,Comme)";

        compareResultAsText(result, expected);
    }

    @Test
    public void testRangePartitionByKeyField2() throws Exception {
        /*
         * Test range partition by key field
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        AggregateOperator<Tuple3<Integer, Long, String>> sum =
                ds.map(new PrefixMapper()).partitionByRange(1, 2).groupBy(1, 2).sum(0);

        List<Tuple3<Integer, Long, String>> result = sum.collect();

        String expected =
                "(1,1,Hi)\n"
                        + "(5,2,Hello)\n"
                        + "(4,3,Hello)\n"
                        + "(5,3,I am )\n"
                        + "(6,3,Luke )\n"
                        + "(34,4,Comme)\n"
                        + "(65,5,Comme)\n"
                        + "(111,6,Comme)";

        compareResultAsText(result, expected);
    }

    @Test
    public void testHashPartitionOfAtomicType() throws Exception {
        /*
         * Test hash partition of atomic type
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> uniqLongs =
                env.generateSequence(1, 6)
                        .union(env.generateSequence(1, 6))
                        .rebalance()
                        .partitionByHash("*")
                        .mapPartition(new UniqueLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testRangePartitionOfAtomicType() throws Exception {
        /*
         * Test range partition of atomic type
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Long> uniqLongs =
                env.generateSequence(1, 6)
                        .union(env.generateSequence(1, 6))
                        .rebalance()
                        .partitionByRange("*")
                        .mapPartition(new UniqueLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testHashPartitionByKeySelector() throws Exception {
        /*
         * Test hash partition by key selector
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByHash(new KeySelector1()).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    /** Truncates the String field {@code f2} to at most 5 characters, mutating the input tuple. */
    private static class PrefixMapper
            implements MapFunction<Tuple3<Integer, Long, String>, Tuple3<Integer, Long, String>> {
        @Override
        public Tuple3<Integer, Long, String> map(Tuple3<Integer, Long, String> value)
                throws Exception {
            if (value.f2.length() > 5) {
                value.f2 = value.f2.substring(0, 5);
            }
            return value;
        }
    }

    @Test
    public void testRangePartitionByKeySelector() throws Exception {
        /*
         * Test range partition by key selector
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByRange(new KeySelector1()).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    /** Selects the {@code Long} field {@code f1} of a {@link Tuple3} as the key. */
    private static class KeySelector1 implements KeySelector<Tuple3<Integer, Long, String>, Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public Long getKey(Tuple3<Integer, Long, String> value) throws Exception {
            return value.f1;
        }
    }

    @Test
    public void testForcedRebalancing() throws Exception {
        /*
         * Test forced rebalancing
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();

        // generate some number in parallel
        DataSet<Long> ds = env.generateSequence(1, 3000);
        DataSet<Tuple2<Integer, Integer>> uniqLongs =
                ds
                        // introduce some partition skew by filtering
                        .filter(new Filter1())
                        // rebalance
                        .rebalance()
                        // count values in each partition
                        .map(new PartitionIndexMapper())
                        .groupBy(0)
                        .reduce(new Reducer1())
                        // round counts to mitigate runtime scheduling effects (lazy split
                        // assignment)
                        .map(new Mapper1());

        List<Tuple2<Integer, Integer>> result = uniqLongs.collect();

        StringBuilder expected = new StringBuilder();
        // 2220 = 3000 - 780 values survive Filter1; after rebalancing each subtask should receive
        // an equal share, divided by 10 to match the rounding done in Mapper1.
        int numPerPartition = 2220 / env.getParallelism() / 10;
        for (int i = 0; i < env.getParallelism(); i++) {
            expected.append('(').append(i).append(',').append(numPerPartition).append(")\n");
        }

        compareResultAsText(result, expected.toString());
    }

    /** Keeps only values greater than 780, skewing the data towards the later input splits. */
    private static class Filter1 implements FilterFunction<Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public boolean filter(Long value) throws Exception {
            return value > 780;
        }
    }

    /** Sums the {@code f1} counts of two tuples sharing the same key {@code f0}. */
    private static class Reducer1 implements ReduceFunction<Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Integer, Integer> reduce(
                Tuple2<Integer, Integer> v1, Tuple2<Integer, Integer> v2) {
            return new Tuple2<>(v1.f0, v1.f1 + v2.f1);
        }
    }

    /** Rounds the count in {@code f1} down to tens (integer division by 10), in place. */
    private static class Mapper1
            implements MapFunction<Tuple2<Integer, Integer>, Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Integer, Integer> map(Tuple2<Integer, Integer> value) throws Exception {
            value.f1 = (value.f1 / 10);
            return value;
        }
    }

    @Test
    public void testHashPartitionByKeyFieldAndDifferentParallelism() throws Exception {
        /*
         * Test hash partition by key field and different parallelism
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(3);

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByHash(1).setParallelism(4).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testRangePartitionByKeyFieldAndDifferentParallelism() throws Exception {
        /*
         * Test range partition by key field and different parallelism
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(3);

        DataSet<Tuple3<Integer, Long, String>> ds = CollectionDataSets.get3TupleDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByRange(1).setParallelism(4).mapPartition(new UniqueTupleLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "1\n" + "2\n" + "3\n" + "4\n" + "5\n" + "6\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testHashPartitionWithKeyExpression() throws Exception {
        /*
         * Test hash partition with key expression
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(3);

        DataSet<POJO> ds = CollectionDataSets.getDuplicatePojoDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByHash("nestedPojo.longNumber")
                        .setParallelism(4)
                        .mapPartition(new UniqueNestedPojoLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "10000\n" + "20000\n" + "30000\n";

        compareResultAsText(result, expected);
    }

    @Test
    public void testRangePartitionWithKeyExpression() throws Exception {
        /*
         * Test range partition with key expression
         */

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(3);

        DataSet<POJO> ds = CollectionDataSets.getDuplicatePojoDataSet(env);
        DataSet<Long> uniqLongs =
                ds.partitionByRange("nestedPojo.longNumber")
                        .setParallelism(4)
                        .mapPartition(new UniqueNestedPojoLongMapper());
        List<Long> result = uniqLongs.collect();

        String expected = "10000\n" + "20000\n" + "30000\n";

        compareResultAsText(result, expected);
    }

    /** Emits each distinct {@code f1} value seen in a partition exactly once. */
    private static class UniqueTupleLongMapper
            implements MapPartitionFunction<Tuple3<Integer, Long, String>, Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public void mapPartition(
                Iterable<Tuple3<Integer, Long, String>> records, Collector<Long> out)
                throws Exception {
            HashSet<Long> uniq = new HashSet<>();
            for (Tuple3<Integer, Long, String> t : records) {
                uniq.add(t.f1);
            }
            for (Long l : uniq) {
                out.collect(l);
            }
        }
    }

    /** Emits each distinct {@code Long} seen in a partition exactly once. */
    private static class UniqueLongMapper implements MapPartitionFunction<Long, Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public void mapPartition(Iterable<Long> longs, Collector<Long> out) throws Exception {
            HashSet<Long> uniq = new HashSet<>();
            for (Long l : longs) {
                uniq.add(l);
            }
            for (Long l : uniq) {
                out.collect(l);
            }
        }
    }

    /** Emits each distinct {@code nestedPojo.longNumber} seen in a partition exactly once. */
    private static class UniqueNestedPojoLongMapper implements MapPartitionFunction<POJO, Long> {
        private static final long serialVersionUID = 1L;

        @Override
        public void mapPartition(Iterable<POJO> records, Collector<Long> out) throws Exception {
            HashSet<Long> uniq = new HashSet<>();
            for (POJO t : records) {
                uniq.add(t.nestedPojo.longNumber);
            }
            for (Long l : uniq) {
                out.collect(l);
            }
        }
    }

    /** Maps every record to {@code (index of this subtask, 1)} so records can be counted. */
    private static class PartitionIndexMapper
            extends RichMapFunction<Long, Tuple2<Integer, Integer>> {
        private static final long serialVersionUID = 1L;

        @Override
        public Tuple2<Integer, Integer> map(Long value) throws Exception {
            return new Tuple2<>(this.getRuntimeContext().getIndexOfThisSubtask(), 1);
        }
    }

    @Test
    public void testRangePartitionerOnSequenceData() throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSource<Long> dataSource = env.generateSequence(0, 10000);
        KeySelector<Long, Long> keyExtractor = new ObjectSelfKeySelector();

        MapPartitionFunction<Long, Tuple2<Long, Long>> minMaxSelector =
                new MinMaxSelector<>(new LongComparator(true));

        Comparator<Tuple2<Long, Long>> tuple2Comparator =
                new Tuple2Comparator<>(new LongComparator(true));

        List<Tuple2<Long, Long>> collected =
                dataSource.partitionByRange(keyExtractor).mapPartition(minMaxSelector).collect();
        Collections.sort(collected, tuple2Comparator);

        // After sorting by (min, max), consecutive partitions must cover adjacent ranges of the
        // sequence. The sequence starts at 0, so -1 marks "no previous partition yet".
        long previousMax = -1;
        for (Tuple2<Long, Long> tuple2 : collected) {
            assertTrue(
                    "Min element in each partition should be smaller than max.",
                    tuple2.f0 <= tuple2.f1);
            if (previousMax != -1) {
                assertEquals(previousMax + 1, tuple2.f0.longValue());
            }
            previousMax = tuple2.f1;
        }
    }

    @Test(expected = InvalidProgramException.class)
    public void testRangePartitionInIteration() throws Exception {

        // does not apply for collection execution
        if (super.mode == TestExecutionMode.COLLECTION) {
            throw new InvalidProgramException("Does not apply for collection execution");
        }

        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSource<Long> source = env.generateSequence(0, 10000);

        DataSet<Tuple2<Long, String>> tuples =
                source.map(
                        new MapFunction<Long, Tuple2<Long, String>>() {
                            @Override
                            public Tuple2<Long, String> map(Long v) throws Exception {
                                return new Tuple2<>(v, Long.toString(v));
                            }
                        });

        DeltaIteration<Tuple2<Long, String>, Tuple2<Long, String>> it =
                tuples.iterateDelta(tuples, 10, 0);
        DataSet<Tuple2<Long, String>> body =
                it.getWorkset()
                        // Verify that range partition is not allowed in iteration
                        .partitionByRange(1)
                        .join(it.getSolutionSet())
                        .where(0)
                        .equalTo(0)
                        .projectFirst(0)
                        .projectSecond(1);
        DataSet<Tuple2<Long, String>> result = it.closeWith(body, body);

        result.collect(); // should fail
    }

    @Test
    public void testRangePartitionerOnSequenceDataWithOrders() throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        DataSet<Tuple2<Long, Long>> dataSet =
                env.generateSequence(0, 10000)
                        .map(
                                new MapFunction<Long, Tuple2<Long, Long>>() {
                                    @Override
                                    public Tuple2<Long, Long> map(Long value) throws Exception {
                                        return new Tuple2<>(value / 5000, value % 5000);
                                    }
                                });

        final Tuple2Comparator<Long> tuple2Comparator =
                new Tuple2Comparator<>(new LongComparator(true), new LongComparator(false));

        MinMaxSelector<Tuple2<Long, Long>> minMaxSelector = new MinMaxSelector<>(tuple2Comparator);

        final List<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>> collected =
                dataSet.partitionByRange(0, 1)
                        .withOrders(Order.ASCENDING, Order.DESCENDING)
                        .mapPartition(minMaxSelector)
                        .collect();

        Collections.sort(collected, new Tuple2Comparator<>(tuple2Comparator));

        Tuple2<Long, Long> previousMax = null;
        for (Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> tuple2 : collected) {
            assertTrue(
                    "Min element in each partition should be smaller than max.",
                    tuple2Comparator.compare(tuple2.f0, tuple2.f1) <= 0);
            if (previousMax == null) {
                previousMax = tuple2.f1;
            } else {
                assertTrue(
                        "Partitions overlap. Previous max should be smaller than current min.",
                        tuple2Comparator.compare(previousMax, tuple2.f0) < 0);
                if (previousMax.f0.equals(tuple2.f0.f0)) {
                    // check that ordering on the second key is correct (descending, so the next
                    // partition starts one below the previous maximum)
                    assertEquals(
                            "Ordering on the second field should be continous.",
                            previousMax.f1 - 1,
                            tuple2.f0.f1.longValue());
                }
                previousMax = tuple2.f1;
            }
        }
    }

    @Test
    public void testRangePartitionerOnSequenceNestedDataWithOrders() throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        final DataSet<Tuple2<Tuple2<Long, Long>, Long>> dataSet =
                env.generateSequence(0, 10000)
                        .map(
                                new MapFunction<Long, Tuple2<Tuple2<Long, Long>, Long>>() {
                                    @Override
                                    public Tuple2<Tuple2<Long, Long>, Long> map(Long value)
                                            throws Exception {
                                        return new Tuple2<>(
                                                new Tuple2<>(value / 5000, value % 5000), value);
                                    }
                                });

        final Tuple2Comparator<Long> tuple2Comparator =
                new Tuple2Comparator<>(new LongComparator(true), new LongComparator(true));
        MinMaxSelector<Tuple2<Long, Long>> minMaxSelector = new MinMaxSelector<>(tuple2Comparator);

        final List<Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>>> collected =
                dataSet.partitionByRange(0)
                        .withOrders(Order.ASCENDING)
                        .mapPartition(
                                new MapPartitionFunction<
                                        Tuple2<Tuple2<Long, Long>, Long>, Tuple2<Long, Long>>() {
                                    @Override
                                    public void mapPartition(
                                            Iterable<Tuple2<Tuple2<Long, Long>, Long>> values,
                                            Collector<Tuple2<Long, Long>> out)
                                            throws Exception {
                                        for (Tuple2<Tuple2<Long, Long>, Long> value : values) {
                                            out.collect(value.f0);
                                        }
                                    }
                                })
                        .mapPartition(minMaxSelector)
                        .collect();

        Collections.sort(collected, new Tuple2Comparator<>(tuple2Comparator));

        Tuple2<Long, Long> previousMax = null;
        for (Tuple2<Tuple2<Long, Long>, Tuple2<Long, Long>> tuple2 : collected) {
            assertTrue(
                    "Min element in each partition should be smaller than max.",
                    tuple2Comparator.compare(tuple2.f0, tuple2.f1) <= 0);
            if (previousMax == null) {
                previousMax = tuple2.f1;
            } else {
                assertTrue(
                        "Partitions overlap. Previous max should be smaller than current min.",
                        tuple2Comparator.compare(previousMax, tuple2.f0) < 0);
                if (previousMax.f0.equals(tuple2.f0.f0)) {
                    // ascending on the second field: the next partition starts one above the
                    // previous maximum
                    assertEquals(
                            "Ordering on the second field should be continous.",
                            previousMax.f1 + 1,
                            tuple2.f0.f1.longValue());
                }
                previousMax = tuple2.f1;
            }
        }
    }

    @Test
    public void testRangePartitionerWithKeySelectorOnSequenceNestedDataWithOrders()
            throws Exception {
        final ExecutionEnvironment env = ExecutionEnvironment.getExecutionEnvironment();
        final DataSet<Tuple2<ComparablePojo, Long>> dataSet =
                env.generateSequence(0, 10000)
                        .map(
                                new MapFunction<Long, Tuple2<ComparablePojo, Long>>() {
                                    @Override
                                    public Tuple2<ComparablePojo, Long> map(Long value)
                                            throws Exception {
                                        return new Tuple2<>(
                                                new ComparablePojo(value / 5000, value % 5000),
                                                value);
                                    }
                                });

        final List<Tuple2<ComparablePojo, ComparablePojo>> collected =
                dataSet.partitionByRange(
                                new KeySelector<Tuple2<ComparablePojo, Long>, ComparablePojo>() {
                                    @Override
                                    public ComparablePojo getKey(Tuple2<ComparablePojo, Long> value)
                                            throws Exception {
                                        return value.f0;
                                    }
                                })
                        .withOrders(Order.ASCENDING)
                        .mapPartition(new MinMaxSelector<>(new ComparablePojoComparator()))
                        .mapPartition(new ExtractComparablePojo())
                        .collect();

        // Order the (min, max) pairs by their minimum. Comparing min against min (rather than
        // min against max) keeps the comparator reflexive and consistent, as Collections.sort
        // requires.
        final Comparator<Tuple2<ComparablePojo, ComparablePojo>> pojoComparator =
                new Comparator<Tuple2<ComparablePojo, ComparablePojo>>() {
                    @Override
                    public int compare(
                            Tuple2<ComparablePojo, ComparablePojo> o1,
                            Tuple2<ComparablePojo, ComparablePojo> o2) {
                        return o1.f0.compareTo(o2.f0);
                    }
                };
        Collections.sort(collected, pojoComparator);

        ComparablePojo previousMax = null;
        for (Tuple2<ComparablePojo, ComparablePojo> element : collected) {
            assertTrue(
                    "Min element in each partition should be smaller than max.",
                    element.f0.compareTo(element.f1) <= 0);
            if (previousMax == null) {
                previousMax = element.f1;
            } else {
                assertTrue(
                        "Partitions overlap. Previous max should be smaller than current min.",
                        previousMax.compareTo(element.f0) < 0);
                if (previousMax.first.equals(element.f0.first)) {
                    // ComparablePojo orders its second field descending
                    assertEquals(
                            "Ordering on the second field should be continous.",
                            previousMax.second - 1,
                            element.f0.second.longValue());
                }
                previousMax = element.f1;
            }
        }
    }

    /** Projects a {@code ((pojo, long), (pojo, long))} min/max pair to {@code (pojo, pojo)}. */
    private static class ExtractComparablePojo
            implements MapPartitionFunction<
                    Tuple2<Tuple2<ComparablePojo, Long>, Tuple2<ComparablePojo, Long>>,
                    Tuple2<ComparablePojo, ComparablePojo>> {

        @Override
        public void mapPartition(
                Iterable<Tuple2<Tuple2<ComparablePojo, Long>, Tuple2<ComparablePojo, Long>>> values,
                Collector<Tuple2<ComparablePojo, ComparablePojo>> out)
                throws Exception {
            for (Tuple2<Tuple2<ComparablePojo, Long>, Tuple2<ComparablePojo, Long>> value :
                    values) {
                out.collect(new Tuple2<>(value.f0.f0, value.f1.f0));
            }
        }
    }

    /** Orders {@code (ComparablePojo, Long)} tuples by their {@link ComparablePojo} field only. */
    private static class ComparablePojoComparator
            implements Comparator<Tuple2<ComparablePojo, Long>>, Serializable {

        @Override
        public int compare(Tuple2<ComparablePojo, Long> o1, Tuple2<ComparablePojo, Long> o2) {
            return o1.f0.compareTo(o2.f0);
        }
    }

    /** Two-field key type ordered ascending on {@code first}, then descending on {@code second}. */
    private static class ComparablePojo implements Comparable<ComparablePojo> {
        private Long first;
        private Long second;

        public Long getFirst() {
            return first;
        }

        public void setFirst(Long first) {
            this.first = first;
        }

        public Long getSecond() {
            return second;
        }

        public void setSecond(Long second) {
            this.second = second;
        }

        public ComparablePojo(Long first, Long second) {
            this.first = first;
            this.second = second;
        }

        public ComparablePojo() {}

        @Override
        public int compareTo(ComparablePojo o) {
            final int firstResult = Long.compare(this.first, o.first);
            if (firstResult == 0) {
                return (-1) * Long.compare(this.second, o.second);
            }

            return firstResult;
        }
    }

    /** Identity key selector: uses the record itself as its key. */
    private static class ObjectSelfKeySelector implements KeySelector<Long, Long> {
        @Override
        public Long getKey(Long value) throws Exception {
            return value;
        }
    }

    /**
     * Emits one {@code (min, max)} tuple per partition, as determined by the given comparator.
     * Empty partitions emit nothing.
     */
    private static class MinMaxSelector<T> implements MapPartitionFunction<T, Tuple2<T, T>> {

        private final Comparator<T> comparator;

        public MinMaxSelector(Comparator<T> comparator) {
            this.comparator = comparator;
        }

        @Override
        public void mapPartition(Iterable<T> values, Collector<Tuple2<T, T>> out) throws Exception {
            Iterator<T> itr = values.iterator();
            if (!itr.hasNext()) {
                // a partition may receive no records; there is no min/max to report
                return;
            }
            T min = itr.next();
            T max = min;
            while (itr.hasNext()) {
                T value = itr.next();
                if (comparator.compare(value, min) < 0) {
                    min = value;
                }
                if (comparator.compare(value, max) > 0) {
                    max = value;
                }
            }

            out.collect(new Tuple2<>(min, max));
        }
    }

    /**
     * Lexicographic comparator for {@code Tuple2<T, T>}: compares {@code f0} with the first
     * comparator and, on ties, {@code f1} with the second. The result is normalized to -1, 0 or 1.
     */
    private static class Tuple2Comparator<T> implements Comparator<Tuple2<T, T>>, Serializable {

        private final Comparator<T> firstComparator;
        private final Comparator<T> secondComparator;

        public Tuple2Comparator(Comparator<T> comparator) {
            this(comparator, comparator);
        }

        public Tuple2Comparator(Comparator<T> firstComparator, Comparator<T> secondComparator) {
            this.firstComparator = firstComparator;
            this.secondComparator = secondComparator;
        }

        @Override
        public int compare(Tuple2<T, T> first, Tuple2<T, T> second) {
            int result = firstComparator.compare(first.f0, second.f0);
            if (result == 0) {
                result = secondComparator.compare(first.f1, second.f1);
            }
            return Integer.signum(result);
        }
    }

    /** Compares {@code Long}s in natural order, or in reverse order when not ascending. */
    private static class LongComparator implements Comparator<Long>, Serializable {

        private final boolean ascending;

        public LongComparator(boolean ascending) {
            this.ascending = ascending;
        }

        @Override
        public int compare(Long o1, Long o2) {
            if (ascending) {
                return Long.compare(o1, o2);
            } else {
                return (-1) * Long.compare(o1, o2);
            }
        }
    }
}
